package com.wy.repository;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;

import org.springframework.dao.DataRetrievalFailureException;
import org.springframework.security.jackson2.CoreJackson2Module;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.oauth2.client.jackson2.OAuth2ClientJackson2Module;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2DeviceCode;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.OAuth2UserCode;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.jackson2.OAuth2AuthorizationServerJackson2Module;
import org.springframework.security.web.jackson2.WebServletJackson2Module;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

/**
 * 基于redis的授权管理服务
 *
 * @author 飞花梦影
 * @date 2024-11-07 09:30:56
 * @git {@link https://github.com/dreamFlyingFlower}
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class RedisOAuth2AuthorizationService implements OAuth2AuthorizationService {

	private final RegisteredClientRepository registeredClientRepository;

	private final RedisOAuth2AuthorizationRepository oAuth2AuthorizationRepository;

	private final static ObjectMapper MAPPER = new ObjectMapper();

	static {
		// 初始化序列化配置
		ClassLoader classLoader = RedisOAuth2AuthorizationService.class.getClassLoader();
		// 加载security提供的Modules
		List<Module> modules = SecurityJackson2Modules.getModules(classLoader);
		MAPPER.registerModules(modules);
		// 加载Security提供的Module
		MAPPER.registerModule(new CoreJackson2Module());
		// 加载Authorization Server提供的Module
		MAPPER.registerModule(new OAuth2AuthorizationServerJackson2Module());
		// 添加Security web提供的Module
		MAPPER.registerModule(new WebServletJackson2Module());
		// 添加OAuth2 Client提供的Module
		MAPPER.registerModule(new OAuth2ClientJackson2Module());
	}

	@Override
	public void save(OAuth2Authorization authorization) {
		Optional<RedisOAuth2Authorization> existingAuthorization =
				oAuth2AuthorizationRepository.findById(authorization.getId());

		// 如果已存在则删除后再保存
		existingAuthorization.map(RedisOAuth2Authorization::getId).ifPresent(oAuth2AuthorizationRepository::deleteById);

		// 过期时间,默认30分钟过期
		long maxTimeout = 30L;
		// 所有code的过期时间,方便计算最大值
		List<Instant> expiresAtList = new ArrayList<>();

		RedisOAuth2Authorization entity = toEntity(authorization);

		// 如果有过期时间就存入
		Optional.ofNullable(entity.getAuthorizationCodeExpiresAt()).ifPresent(expiresAtList::add);

		// 如果有过期时间就存入
		Optional.ofNullable(entity.getAccessTokenExpiresAt()).ifPresent(expiresAtList::add);

		// 如果有过期时间就存入
		Optional.ofNullable(entity.getRefreshTokenExpiresAt()).ifPresent(expiresAtList::add);

		// 如果有过期时间就存入
		Optional.ofNullable(entity.getOidcIdTokenExpiresAt()).ifPresent(expiresAtList::add);

		// 如果有过期时间就存入
		Optional.ofNullable(entity.getUserCodeExpiresAt()).ifPresent(expiresAtList::add);

		// 如果有过期时间就存入
		Optional.ofNullable(entity.getDeviceCodeExpiresAt()).ifPresent(expiresAtList::add);

		// 获取最大的日期
		Optional<Instant> maxInstant = expiresAtList.stream().max(Comparator.comparing(Instant::getEpochSecond));
		if (maxInstant.isPresent()) {
			// 计算时间差
			Duration between = Duration.between(Instant.now(), maxInstant.get());
			// 转为分钟
			maxTimeout = between.toMinutes();
		}

		// 设置过期时间
		entity.setTimeout(maxTimeout);

		// 保存至redis
		oAuth2AuthorizationRepository.save(entity);
	}

	@Override
	public void remove(OAuth2Authorization authorization) {
		Assert.notNull(authorization, "authorization cannot be null");
		oAuth2AuthorizationRepository.deleteById(authorization.getId());
	}

	@Override
	public OAuth2Authorization findById(String id) {
		Assert.hasText(id, "id cannot be empty");
		return oAuth2AuthorizationRepository.findById(id).map(this::toObject).orElse(null);
	}

	@Override
	public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
		Assert.hasText(token, "token cannot be empty");

		Optional<RedisOAuth2Authorization> result;

		if (tokenType == null) {
			result = oAuth2AuthorizationRepository.findByState(token)
					.or(() -> oAuth2AuthorizationRepository.findByAuthorizationCodeValue(token))
					.or(() -> oAuth2AuthorizationRepository.findByAccessTokenValue(token))
					.or(() -> oAuth2AuthorizationRepository.findByOidcIdTokenValue(token))
					.or(() -> oAuth2AuthorizationRepository.findByRefreshTokenValue(token))
					.or(() -> oAuth2AuthorizationRepository.findByUserCodeValue(token))
					.or(() -> oAuth2AuthorizationRepository.findByDeviceCodeValue(token));
		} else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
			result = oAuth2AuthorizationRepository.findByState(token);
		} else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
			result = oAuth2AuthorizationRepository.findByAuthorizationCodeValue(token);
		} else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
			result = oAuth2AuthorizationRepository.findByAccessTokenValue(token);
		} else if (OidcParameterNames.ID_TOKEN.equals(tokenType.getValue())) {
			result = oAuth2AuthorizationRepository.findByOidcIdTokenValue(token);
		} else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
			result = oAuth2AuthorizationRepository.findByRefreshTokenValue(token);
		} else if (OAuth2ParameterNames.USER_CODE.equals(tokenType.getValue())) {
			result = oAuth2AuthorizationRepository.findByUserCodeValue(token);
		} else if (OAuth2ParameterNames.DEVICE_CODE.equals(tokenType.getValue())) {
			result = oAuth2AuthorizationRepository.findByDeviceCodeValue(token);
		} else {
			result = Optional.empty();
		}

		return result.map(this::toObject).orElse(null);
	}

	/**
	 * 将redis中存储的类型转为框架所需的类型
	 *
	 * @param entity redis中存储的类型
	 * @return 框架所需的类型
	 */
	private OAuth2Authorization toObject(RedisOAuth2Authorization entity) {
		RegisteredClient registeredClient = this.registeredClientRepository.findById(entity.getRegisteredClientId());
		if (registeredClient == null) {
			throw new DataRetrievalFailureException("The RegisteredClient with id '" + entity.getRegisteredClientId()
					+ "' was not found in the RegisteredClientRepository.");
		}

		OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
				.id(entity.getId())
				.principalName(entity.getPrincipalName())
				.authorizationGrantType(resolveAuthorizationGrantType(entity.getAuthorizationGrantType()))
				.authorizedScopes(StringUtils.commaDelimitedListToSet(entity.getAuthorizedScopes()))
				.attributes(attributes -> attributes.putAll(parseMap(entity.getAttributes())));
		if (entity.getState() != null) {
			builder.attribute(OAuth2ParameterNames.STATE, entity.getState());
		}

		if (entity.getAuthorizationCodeValue() != null) {
			OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(entity.getAuthorizationCodeValue(),
					entity.getAuthorizationCodeIssuedAt(), entity.getAuthorizationCodeExpiresAt());
			builder.token(authorizationCode,
					metadata -> metadata.putAll(parseMap(entity.getAuthorizationCodeMetadata())));
		}

		if (entity.getAccessTokenValue() != null) {
			OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
					entity.getAccessTokenValue(), entity.getAccessTokenIssuedAt(), entity.getAccessTokenExpiresAt(),
					StringUtils.commaDelimitedListToSet(entity.getAccessTokenScopes()));
			builder.token(accessToken, metadata -> metadata.putAll(parseMap(entity.getAccessTokenMetadata())));
		}

		if (entity.getRefreshTokenValue() != null) {
			OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(entity.getRefreshTokenValue(),
					entity.getRefreshTokenIssuedAt(), entity.getRefreshTokenExpiresAt());
			builder.token(refreshToken, metadata -> metadata.putAll(parseMap(entity.getRefreshTokenMetadata())));
		}

		if (entity.getOidcIdTokenValue() != null) {
			OidcIdToken idToken = new OidcIdToken(entity.getOidcIdTokenValue(), entity.getOidcIdTokenIssuedAt(),
					entity.getOidcIdTokenExpiresAt(), parseMap(entity.getOidcIdTokenClaims()));
			builder.token(idToken, metadata -> metadata.putAll(parseMap(entity.getOidcIdTokenMetadata())));
		}

		if (entity.getUserCodeValue() != null) {
			OAuth2UserCode userCode = new OAuth2UserCode(entity.getUserCodeValue(), entity.getUserCodeIssuedAt(),
					entity.getUserCodeExpiresAt());
			builder.token(userCode, metadata -> metadata.putAll(parseMap(entity.getUserCodeMetadata())));
		}

		if (entity.getDeviceCodeValue() != null) {
			OAuth2DeviceCode deviceCode = new OAuth2DeviceCode(entity.getDeviceCodeValue(),
					entity.getDeviceCodeIssuedAt(), entity.getDeviceCodeExpiresAt());
			builder.token(deviceCode, metadata -> metadata.putAll(parseMap(entity.getDeviceCodeMetadata())));
		}

		return builder.build();
	}

	/**
	 * 将框架所需的类型转为redis中存储的类型
	 *
	 * @param authorization 框架所需的类型
	 * @return redis中存储的类型
	 */
	private RedisOAuth2Authorization toEntity(OAuth2Authorization authorization) {
		RedisOAuth2Authorization entity = new RedisOAuth2Authorization();
		entity.setId(authorization.getId());
		entity.setRegisteredClientId(authorization.getRegisteredClientId());
		entity.setPrincipalName(authorization.getPrincipalName());
		entity.setAuthorizationGrantType(authorization.getAuthorizationGrantType().getValue());
		entity.setAuthorizedScopes(StringUtils.collectionToDelimitedString(authorization.getAuthorizedScopes(), ","));
		entity.setAttributes(writeMap(authorization.getAttributes()));
		entity.setState(authorization.getAttribute(OAuth2ParameterNames.STATE));

		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
				authorization.getToken(OAuth2AuthorizationCode.class);
		setTokenValues(authorizationCode, entity::setAuthorizationCodeValue, entity::setAuthorizationCodeIssuedAt,
				entity::setAuthorizationCodeExpiresAt, entity::setAuthorizationCodeMetadata);

		OAuth2Authorization.Token<OAuth2AccessToken> accessToken = authorization.getToken(OAuth2AccessToken.class);
		setTokenValues(accessToken, entity::setAccessTokenValue, entity::setAccessTokenIssuedAt,
				entity::setAccessTokenExpiresAt, entity::setAccessTokenMetadata);
		if (accessToken != null && accessToken.getToken().getScopes() != null) {
			entity.setAccessTokenScopes(
					StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ","));
		}

		OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getToken(OAuth2RefreshToken.class);
		setTokenValues(refreshToken, entity::setRefreshTokenValue, entity::setRefreshTokenIssuedAt,
				entity::setRefreshTokenExpiresAt, entity::setRefreshTokenMetadata);

		OAuth2Authorization.Token<OidcIdToken> oidcIdToken = authorization.getToken(OidcIdToken.class);
		setTokenValues(oidcIdToken, entity::setOidcIdTokenValue, entity::setOidcIdTokenIssuedAt,
				entity::setOidcIdTokenExpiresAt, entity::setOidcIdTokenMetadata);
		if (oidcIdToken != null) {
			entity.setOidcIdTokenClaims(writeMap(oidcIdToken.getClaims()));
		}

		OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
		setTokenValues(userCode, entity::setUserCodeValue, entity::setUserCodeIssuedAt, entity::setUserCodeExpiresAt,
				entity::setUserCodeMetadata);

		OAuth2Authorization.Token<OAuth2DeviceCode> deviceCode = authorization.getToken(OAuth2DeviceCode.class);
		setTokenValues(deviceCode, entity::setDeviceCodeValue, entity::setDeviceCodeIssuedAt,
				entity::setDeviceCodeExpiresAt, entity::setDeviceCodeMetadata);

		return entity;
	}

	/**
	 * 设置token的值
	 *
	 * @param token Token实例
	 * @param tokenValueConsumer set方法
	 * @param issuedAtConsumer set方法
	 * @param expiresAtConsumer set方法
	 * @param metadataConsumer set方法
	 */
	private void setTokenValues(OAuth2Authorization.Token<?> token, Consumer<String> tokenValueConsumer,
			Consumer<Instant> issuedAtConsumer, Consumer<Instant> expiresAtConsumer,
			Consumer<String> metadataConsumer) {
		if (token != null) {
			OAuth2Token oAuth2Token = token.getToken();
			tokenValueConsumer.accept(oAuth2Token.getTokenValue());
			issuedAtConsumer.accept(oAuth2Token.getIssuedAt());
			expiresAtConsumer.accept(oAuth2Token.getExpiresAt());
			metadataConsumer.accept(writeMap(token.getMetadata()));
		}
	}

	/**
	 * 处理授权申请时的 GrantType
	 *
	 * @param authorizationGrantType 授权申请时的 GrantType
	 * @return AuthorizationGrantType的实例
	 */
	private static AuthorizationGrantType resolveAuthorizationGrantType(String authorizationGrantType) {
		if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.AUTHORIZATION_CODE;
		} else if (AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.CLIENT_CREDENTIALS;
		} else if (AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.REFRESH_TOKEN;
		} else if (AuthorizationGrantType.DEVICE_CODE.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.DEVICE_CODE;
		}
		// Custom authorization grant type
		return new AuthorizationGrantType(authorizationGrantType);
	}

	/**
	 * 将json转为map
	 *
	 * @param data json
	 * @return map对象
	 */
	private Map<String, Object> parseMap(String data) {
		try {
			return MAPPER.readValue(data, new TypeReference<>() {
			});
		} catch (Exception ex) {
			log.info("Exception json：{}", data);
			throw new IllegalArgumentException(ex.getMessage(), ex);
		}
	}

	/**
	 * 将map对象转为json字符串
	 *
	 * @param metadata map对象
	 * @return json字符串
	 */
	private String writeMap(Map<String, Object> metadata) {
		try {
			return MAPPER.writeValueAsString(metadata);
		} catch (Exception ex) {
			throw new IllegalArgumentException(ex.getMessage(), ex);
		}
	}
}